#include <gtest/gtest.h>
#include <mockcpp/mockcpp.hpp>

#define protected public
#include "common/types.h"
#include "infra/base/securestl.h"
#include "graph/op/array_defs.h"
#include "graph/op/const_defs.h"
#include "graph/op/math_defs.h"
#include "graph/op/nn_defs.h"
#include "framework/graph/debug/ge_op_types.h"
#include "framework/graph/utils/graph_utils.h"
#include "framework/graph/utils/tensor_utils.h"
#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/core/cgraph/graph_finder.h"
#include "framework/graph/core/cgraph/graph_spec.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#include "framework/graph/core/edge/edge.h"
#include "framework/graph/core/optimize/optimize_status.h"
#define private public
#include "framework/graph/core/optimize/schedule/graph_pass.h"
#include "framework/graph/core/optimize/schedule/graph_pass_list.h"
#include "framework/graph/core/optimize/schedule/graph_optimize_scheduler.h"
#undef protected
#undef private

using namespace std;
using namespace hiai;
using namespace ge;

class UTEST_graph_pass_list : public testing::Test {
protected:
    void SetUp()
    {
    }
    void TearDown()
    {
        GlobalMockObject::verify();
    }

protected:
    Node& AddNode(ComputeGraphPtr graph, const string& name, const string& type, int32_t out_anchors_num = 1,
        int32_t in_anchors_num = 1)
    {
        TensorDesc tensor_desc;
        OpDescPtr opdesc = hiai::make_shared_nothrow<OpDesc>(name, type);
        for (int32_t i = 0; i < in_anchors_num; i++) {
            opdesc->AddInputDesc(tensor_desc);
        }
        for (int32_t i = 0; i < out_anchors_num; i++) {
            opdesc->AddOutputDesc(tensor_desc);
        }

        return *graph->ROLE(GraphModifier).AddNode(opdesc);
    }

    ComputeGraphPtr CreateGraph()
    {
        ComputeGraphPtr graph = ge::ComputeGraph::Make("test_graph");

        Node& any = AddNode(graph, "op_any", hiai::op::Reshape::TYPE);
        Node& greater = AddNode(graph, "op_greater", hiai::op::Greater::TYPE, 1, 2);
        Node& cast = AddNode(graph, "op_cast", hiai::op::CastT::TYPE);
        Node& mul = AddNode(graph, "op_mul", hiai::op::Mul::TYPE, 1, 2);
        Node& relu = AddNode(graph, "op_relu", RELU);

        graph->ROLE(GraphModifier).AddEdge({any, 0}, {greater, 0});
        graph->ROLE(GraphModifier).AddEdge({any, 0}, {mul, 1});
        graph->ROLE(GraphModifier).AddEdge({greater, 0}, {cast, 0});
        graph->ROLE(GraphModifier).AddEdge({cast, 0}, {mul, 0});
        graph->ROLE(GraphModifier).AddEdge({mul, 0}, {relu, 0});

        return graph;
    }
};

// Minimal GraphPass stub: it ignores the graph it is given and always reports
// OptimizeStatus::OPTIMIZE_CHANGED.
class TestPass : public GraphPass {
public:
    // The graph parameter is intentionally unnamed: the stub never reads it, and
    // leaving it unnamed keeps -Wunused-parameter quiet.
    OptimizeStatus Run(ge::ComputeGraph& /* graph */) override
    {
        return OptimizeStatus::OPTIMIZE_CHANGED;
    }
};

// Registers a single TestPass in a GraphPassList and calls Optimize on the list directly.
// It then wraps the list in a GraphOptimizeScheduler (second Make argument 1) and expects
// Schedule to return hiai::SUCCESS.
TEST_F(UTEST_graph_pass_list, test_graph_pass)
{
    ComputeGraphPtr graph = CreateGraph();
    // ASSERT (fatal) rather than EXPECT: each pointer checked below is dereferenced
    // immediately afterwards. A null value must abort this test, not crash the binary.
    ASSERT_NE(graph, nullptr);

    auto passList = hiai::GraphPassList::Make();
    ASSERT_NE(passList, nullptr);

    std::unique_ptr<GraphPass> testpass = hiai::make_unique_nothrow<TestPass>();
    ASSERT_NE(testpass, nullptr);
    passList->Add(std::move(testpass));

    // Direct call on the list; its return value is not checked by this test.
    passList->Optimize(*graph);

    // NOTE(review): the second argument presumably bounds the number of scheduling rounds — confirm.
    auto graphPassScheduler = hiai::GraphOptimizeScheduler::Make(*passList, 1);
    ASSERT_NE(graphPassScheduler, nullptr);

    Status ret = graphPassScheduler->Schedule(*graph);
    EXPECT_EQ(ret, hiai::SUCCESS);
}